import torch
import matplotlib.pyplot as plt
import random
import copy
import torch.optim.lr_scheduler as lr_scheduler
from Trace import Covariance
from Visualization import funcaverage

def GD(input, LossFunctions, eps, lr, decay_rate):
    """Run full-batch gradient descent with an exponentially decaying step size.

    Each iteration evaluates every function in ``LossFunctions`` at the current
    iterate, averages them, back-propagates, and takes one SGD step. The
    learning rate used at step ``ep`` is ``lr * decay_rate ** ep`` (LambdaLR
    scales the initial lr by the lambda's value).

    Args:
        input: Starting point. It is deep-copied, so the caller's tensor is not
            modified. It must be a leaf tensor with ``requires_grad=True`` so
            the optimizer can update it. NOTE(review): the original buffers
            were hard-coded to size 2, so callers presumably pass 2-vectors.
            Any 1-D iterate is now supported.
        LossFunctions: Non-empty sequence of callables, each mapping the
            iterate to a scalar loss tensor.
        eps: Number of gradient steps to take.
        lr: Initial learning rate.
        decay_rate: Multiplicative per-step learning-rate decay factor.

    Returns:
        Tuple ``(x, hit_rate, Traj, H, AccCov)``:
        x -- the final iterate (the optimized copy of ``input``).
        hit_rate -- fraction of steps after which the first coordinate of
            ``x`` is positive; ``0.0`` when ``eps == 0`` (no steps taken).
        Traj -- ``(d, eps + 1)`` tensor where ``d = input.numel()``.
            Column ``k`` is the iterate before step ``k``; the last column is
            the final iterate.
        H, AccCov -- the same ``(d, d)`` tensor object. The lr^2-weighted
            covariance accumulation that would fill it is currently
            disabled, so it is all zeros.

    Raises:
        ValueError: if ``LossFunctions`` is empty.
    """
    n_losses = len(LossFunctions)
    if n_losses == 0:
        # Fail early with a clear message instead of a ZeroDivisionError from
        # averaging zero losses (or calling .backward() on a plain int).
        raise ValueError("LossFunctions must contain at least one loss function")

    x = copy.deepcopy(input)
    dim = x.numel()
    # Accumulator for the lr^2-weighted gradient covariance. The accumulation
    # (Covariance(x, LossFunctions) * lr**2 per step) is currently disabled,
    # so this is returned as zeros.
    AccCov = torch.zeros(dim, dim)
    optimizer = torch.optim.SGD([x], lr=lr)
    # LambdaLR multiplies the *initial* lr by this factor at each epoch.
    decay = lambda epoch: decay_rate ** epoch
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=decay)

    PositiveHits = 0
    Traj = torch.zeros(dim, eps + 1)
    for ep in range(eps):
        Traj[:, ep] = x.detach().reshape(-1)
        optimizer.zero_grad()
        # Full-batch objective: mean of all loss functions at the current point.
        loss = sum(f(x) for f in LossFunctions) / n_losses
        loss.backward()
        optimizer.step()
        scheduler.step()
        if x.detach().reshape(-1)[0] > 0:
            PositiveHits += 1

    Traj[:, eps] = x.detach().reshape(-1)
    H = AccCov
    # With zero steps there is nothing to count; report 0.0 rather than
    # dividing by zero.
    hit_rate = PositiveHits / eps if eps > 0 else 0.0
    return x, hit_rate, Traj, H, AccCov